package mysqlgo

import (
	"fmt"
	"strings"
)

var (
	selectSQL = "SELECT%DISTINCT% %FIELD% FROM %TABLE%%JOIN%%WHERE%%GROUP%%HAVING%%ORDER%%LIMIT% %UNION%%COMMENT%"
	insertSQL = "INSERT INTO %TABLE%(%FIELD%) VALUE(%MARK%)"
	updateSQL = "UPDATE %TABLE% SET %FIELD% WHERE %ARGS%"
	deleteSQL = "DELETE FORM %TABLE% WHERE %AGRS%"
)

//SQLBuilder SQL生成工具类
type SQLBuilder struct {
	table      Table
	distinct   bool
	field      string
	join       []Join
	where      string
	whereArgs  []interface{}
	group      []string
	having     string
	order      map[string]bool
	limit      Limit
	union      Union
	comment    string
	page       string
	args       []interface{}
	lastSQL    string
	priviewSQL bool
}

//SQLTYPE SQL类型
type SQLTYPE int

const (
	query  SQLTYPE = 0 //查询
	insert SQLTYPE = 1 //插入
	update SQLTYPE = 2 //更新
	delete SQLTYPE = 3 //删除
)

func (sql *SQLBuilder) printSQL() {
	if sql.priviewSQL {
		fmt.Printf("[SQLBuilder Preview]: %s \n", sql.lastSQL)
	}
}

func (sql *SQLBuilder) tableFormat() string {
	if sql.table.Alias == "" {
		return fmt.Sprintf(" %s ", sql.table.Name)
	}
	return fmt.Sprintf(" %s as %s ", sql.table.Name, sql.table.Alias)
}

func (sql *SQLBuilder) distinctFormat() string {
	if sql.distinct {
		return "DISTINCT"
	}
	return ""
}

func (sql *SQLBuilder) fieldFormat() string {
	if sql.field != "" {
		return sql.field
	}
	return "*"
}

func (sql *SQLBuilder) joinFormat() string {
	var join []string
	for _, value := range sql.join {
		if !(strings.Index(value.Statement, "JOIN") > -1 && strings.Index(value.Statement, "join") > -1) {
			value.Statement = fmt.Sprintf(" JOIN %s ", value.Statement)
		}
		switch value.Type {
		case JOININNER:
			join = append(join, fmt.Sprintf(" INNER %s ", value.Statement))
			break
		case JOINLEFT:
			join = append(join, fmt.Sprintf(" LEFT %s ", value.Statement))
			break
		case JOINRIGHT:
			join = append(join, fmt.Sprintf(" RIGHT %s ", value.Statement))
			break
		case JOINFULL:
			join = append(join, fmt.Sprintf(" FULL %s ", value.Statement))
			break
		default:
			join = append(join, fmt.Sprintf(" INNER %s ", value.Statement))
			break
		}

	}
	return strings.Join(join, ",")
}

func (sql *SQLBuilder) whereFormat() string {
	if sql.where != "" {
		return fmt.Sprintf(" WHERE %s ", sql.where)
	}
	return ""
}

func (sql *SQLBuilder) groupFormat() string {
	if len(sql.group) > 0 {
		return fmt.Sprintf(" GROUP BY %s", strings.Join(sql.group, ","))
	}
	return ""
}

func (sql *SQLBuilder) havingFormat() string {
	if sql.having != "" {
		return fmt.Sprintf(" HAVING %s", sql.having)
	}
	return ""
}

func (sql *SQLBuilder) commentFormat() string {
	if sql.comment != "" {
		return fmt.Sprintf(" /* %s */", sql.comment)
	}
	return ""
}

func (sql *SQLBuilder) orderFormat() string {
	if len(sql.order) > 0 {
		var orderStr []string
		for key, order := range sql.order {
			var str string
			if order {
				str = fmt.Sprintf(" %s desc", key)
			} else {
				str = fmt.Sprintf(" %s asc", key)
			}
			orderStr = append(orderStr, str)
		}
		return fmt.Sprintf(" ORDER BY %s ", strings.Join(orderStr, ","))
	}
	return ""
}

func (sql *SQLBuilder) limitFormat() string {
	if sql.limit.RowCount > 0 {
		if sql.limit.Offset > 0 {
			return fmt.Sprintf(" LIMIT %d OFFSET %d ", sql.limit.RowCount, sql.limit.Offset)
		}
		return fmt.Sprintf(" LIMIT %d ", sql.limit.RowCount)
	}
	return ""
}

func (sql *SQLBuilder) unionFormat() string {
	if len(sql.union.SelectSQL) > 0 {
		if sql.union.All {
			return fmt.Sprintf(" UNION ALL %s ", strings.Join(sql.union.SelectSQL, ","))
		}
		return fmt.Sprintf(" UNION %s ", strings.Join(sql.union.SelectSQL, ","))
	}
	return ""
}

func (sql *SQLBuilder) fieldMarkFormat() string {
	if sql.field == "" || sql.field == "*" {
		return ""
	}
	var count = strings.Count(sql.field, ",")
	var fieldMark []string
	for i := 0; i <= count; i++ {
		fieldMark = append(fieldMark, "?")
	}
	if len(fieldMark) == 0 {
		return ""
	}
	return strings.Join(fieldMark, ",")
}

func (sql *SQLBuilder) parseQuerySQL() string {
	var sqlFormat = selectSQL
	sqlFormat = strings.Replace(sqlFormat, "%TABLE%", sql.tableFormat(), -1)
	sqlFormat = strings.Replace(sqlFormat, "%DISTINCT%", sql.distinctFormat(), -1)
	sqlFormat = strings.Replace(sqlFormat, "%FIELD%", sql.fieldFormat(), -1)
	sqlFormat = strings.Replace(sqlFormat, "%JOIN%", sql.joinFormat(), -1)
	sqlFormat = strings.Replace(sqlFormat, "%WHERE%", sql.whereFormat(), -1)
	sqlFormat = strings.Replace(sqlFormat, "%GROUP%", sql.groupFormat(), -1)
	sqlFormat = strings.Replace(sqlFormat, "%HAVING%", sql.havingFormat(), -1)
	sqlFormat = strings.Replace(sqlFormat, "%ORDER%", sql.orderFormat(), -1)
	sqlFormat = strings.Replace(sqlFormat, "%LIMIT%", sql.limitFormat(), -1)
	sqlFormat = strings.Replace(sqlFormat, "%UNION%", sql.unionFormat(), -1)
	sqlFormat = strings.Replace(sqlFormat, "%COMMENT%", sql.commentFormat(), -1)
	return sqlFormat
}

func (sql *SQLBuilder) parseInsertSQL() string {
	var sqlFormat = insertSQL
	sqlFormat = strings.Replace(sqlFormat, "%TABLE%", sql.tableFormat(), -1)
	sqlFormat = strings.Replace(sqlFormat, "%FIELD%", sql.fieldFormat(), -1)
	sqlFormat = strings.Replace(sqlFormat, "%DISTINCT%", sql.distinctFormat(), -1)
	fieldMark := sql.fieldMarkFormat()
	if fieldMark == "" {
		return ""
	}
	sqlFormat = strings.Replace(sqlFormat, "%MARK%", fieldMark, -1)
	return sqlFormat
}

func (sql *SQLBuilder) parseUpdateSQL(data map[string]interface{}) string {
	var sqlFormat = updateSQL
	sqlFormat = strings.Replace(sqlFormat, "%TABLE%", sql.tableFormat(), -1)

	var keys []string
	var values []interface{}
	for key, value := range data {
		keys = append(keys, fmt.Sprintf("%s = ?", key))
		values = append(values, value)
	}
	sqlFormat = strings.Replace(sqlFormat, "%FIELD%", strings.Join(keys, ","), -1)
	sqlFormat = strings.Replace(sqlFormat, "%ARGS%", sql.whereFormat(), -1)
	sql.args = append(sql.args, values...)
	sql.args = append(sql.args, sql.whereArgs...)
	return sqlFormat
}

func (sql *SQLBuilder) parseDeleteSQL() string {
	var sqlFormat = deleteSQL
	sqlFormat = strings.Replace(sqlFormat, "%TABLE%", sql.tableFormat(), -1)
	if len(sql.whereArgs) == 0 {
		return ""
	}
	sqlFormat = strings.Replace(sqlFormat, "%ARGS%", sql.whereFormat(), -1)
	return sqlFormat
}

//Where add where
func (sql *SQLBuilder) Where(condition string, args ...interface{}) {
	if sql.where != "" {
		sql.where = fmt.Sprintf("%s AND (%s)", sql.where, condition)
	} else {
		sql.where = fmt.Sprintf("(%s)", condition)
	}
	if args != nil {
		if sql.whereArgs != nil {
			sql.whereArgs = append(sql.whereArgs, args...)
		} else {
			sql.whereArgs = args
		}
	}
}

//Field 指定字段名称
func (sql *SQLBuilder) Field(fields ...string) {
	for _, field := range fields {
		if strings.Index(sql.field, field) < 0 {
			if sql.field == "" {
				sql.field = fmt.Sprintf("%s", field)
				continue
			}
			sql.field = fmt.Sprintf("%s, %s", sql.field, field)
		}
	}
}

//BuildSQL 构造SQL
func (sql *SQLBuilder) BuildSQL(t SQLTYPE, data map[string]interface{}) string {
	switch t {
	case query:
		sql.lastSQL = sql.parseQuerySQL()
		break
	case insert:
		sql.lastSQL = sql.parseInsertSQL()
		break
	case update:
		sql.lastSQL = sql.parseUpdateSQL(data)
		break
	case delete:
		sql.lastSQL = sql.parseDeleteSQL()
		break
	}
	sql.printSQL()
	return sql.lastSQL
}
